{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "51CTO课程频道：http://edu.51cto.com/lecturer/index/user_id-12330098.html<br>\n",
    "优酷频道：http://i.youku.com/sdxxqbf<br>\n",
    "微信公众号：深度学习与神经网络<br>\n",
    "Github：https://github.com/Qinbf<br>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.examples.tutorials.mnist import input_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting MNIST_data\\train-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data\\train-labels-idx1-ubyte.gz\n",
      "Extracting MNIST_data\\t10k-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data\\t10k-labels-idx1-ubyte.gz\n",
      "Iter 0,Testing Accuracy 0.8304\n",
      "Iter 1,Testing Accuracy 0.8702\n",
      "Iter 2,Testing Accuracy 0.8821\n",
      "Iter 3,Testing Accuracy 0.8884\n",
      "Iter 4,Testing Accuracy 0.894\n",
      "Iter 5,Testing Accuracy 0.8968\n",
      "Iter 6,Testing Accuracy 0.9011\n",
      "Iter 7,Testing Accuracy 0.9019\n",
      "Iter 8,Testing Accuracy 0.9034\n",
      "Iter 9,Testing Accuracy 0.9049\n",
      "Iter 10,Testing Accuracy 0.9057\n",
      "Iter 11,Testing Accuracy 0.9073\n",
      "Iter 12,Testing Accuracy 0.9081\n",
      "Iter 13,Testing Accuracy 0.9088\n",
      "Iter 14,Testing Accuracy 0.9098\n",
      "Iter 15,Testing Accuracy 0.9108\n",
      "Iter 16,Testing Accuracy 0.9118\n",
      "Iter 17,Testing Accuracy 0.9123\n",
      "Iter 18,Testing Accuracy 0.9127\n",
      "Iter 19,Testing Accuracy 0.9137\n",
      "Iter 20,Testing Accuracy 0.9138\n"
     ]
    }
   ],
   "source": [
    "#载入数据集\n",
    "mnist = input_data.read_data_sets(\"MNIST_data\",one_hot=True)\n",
    "\n",
    "#每个批次100张照片\n",
    "batch_size = 100\n",
    "#计算一共有多少个批次\n",
    "n_batch = mnist.train.num_examples // batch_size\n",
    "\n",
    "#定义两个placeholder\n",
    "x = tf.placeholder(tf.float32,[None,784])\n",
    "y = tf.placeholder(tf.float32,[None,10])\n",
    "\n",
    "#创建一个简单的神经网络，输入层784个神经元，输出层10个神经元\n",
    "W = tf.Variable(tf.zeros([784,10]))\n",
    "b = tf.Variable(tf.zeros([10]))\n",
    "prediction = tf.nn.softmax(tf.matmul(x,W)+b)\n",
    "\n",
    "#二次代价函数\n",
    "#square是求平方\n",
    "#reduce_mean是求平均值\n",
    "loss = tf.reduce_mean(tf.square(y-prediction))\n",
    "\n",
    "#使用梯度下降法来最小化loss，学习率是0.2\n",
    "train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)\n",
    "\n",
    "#初始化变量\n",
    "init = tf.global_variables_initializer()\n",
    "\n",
    "#结果存放在一个布尔型列表中\n",
    "correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一维张量中最大的值所在的位置\n",
    "#求准确率\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))#cast是进行数据格式转换，把布尔型转为float32类型\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    #执行初始化\n",
    "    sess.run(init)\n",
    "    #迭代21个周期\n",
    "    for epoch in range(21):\n",
    "        #每个周期迭代n_batch个batch，每个batch为100\n",
    "        for batch in range(n_batch):\n",
    "            #获得一个batch的数据和标签\n",
    "            batch_xs,batch_ys =  mnist.train.next_batch(batch_size)\n",
    "            #通过feed喂到模型中进行训练\n",
    "            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})\n",
    "        \n",
    "        #计算准确率\n",
    "        acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})\n",
    "        print(\"Iter \" + str(epoch) + \",Testing Accuracy \" + str(acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
